feature: unify new_tokens format sample state to trtllm sampler tokens format #5513
Conversation
…kens format Signed-off-by: Netanel Haber <[email protected]>
PR_Github #10145 [ skip ] triggered by Bot
PR_Github #10145 [ skip ] completed with state
Please ignore the skip, I triggered it by mistake on this PR.
Force-pushed from 6874175 to 035e67a
Signed-off-by: Netanel Haber <[email protected]>
minimize diff
Signed-off-by: Netanel Haber <[email protected]>
minimize diff
Signed-off-by: Netanel Haber <[email protected]>
Force-pushed from 6397d52 to 84138c6
…_with_trtllm_sampler_sample_state Signed-off-by: Netanel Haber <[email protected]>
/bot run --disable-fail-fast
PR_Github #10242 [ run ] triggered by Bot
PR_Github #10242 [ run ] completed with state
wili-65535
left a comment
Great work on simplifying the samplers! LGTM on my side.
…_with_trtllm_sampler_sample_state Signed-off-by: Netanel Haber <[email protected]>
…sampling Signed-off-by: Netanel Haber <[email protected]>
Force-pushed from a68c9fd to 051fe4a
…sampling Signed-off-by: Netanel Haber <[email protected]>
/bot run --disable-fail-fast
PR_Github #10369 [ run ] triggered by Bot
PR_Github #10369 [ run ] completed with state
dcampora
left a comment
Approving, as perf issue is now fixed.
suyoggupta
left a comment
LGTM for AD changes
Pull Request Overview
This PR re-applies previously reverted speculative decoding changes and fixes a performance regression in TorchSampler by unifying the new_tokens state format and refactoring sampler interfaces across the codebase.
- Refactored `get_spec_decoder` to accept `TorchSampler.Args` and integrated `TorchSampler` in speculative modes.
- Overhauled the `TorchSampler` API: introduced `Args`/`Store` dataclasses, generic sampling helpers, and unified `sample_async`/`update_requests`.
- Removed legacy sampler classes (`Eagle3Sampler`, `Eagle3Decoder`) and updated resource managers and the scheduler to use `all_requests`.
Reviewed Changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tensorrt_llm/_torch/speculative/utils.py | Updated get_spec_decoder signature and error handling for unsupported modes. |
| tensorrt_llm/_torch/speculative/mtp.py | Adapted MTPSampler to new TorchSampler.Args and simplified stop‐criteria calls. |
| tensorrt_llm/_torch/speculative/eagle3.py | Removed legacy Eagle3 sampler classes, added Eagle3OneModelSampler. |
| tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py | Switched loops to use scheduled_batch.all_requests(). |
| tensorrt_llm/_torch/pyexecutor/scheduler.py | Simplified all_requests to return a list instead of chain. |
| tensorrt_llm/_torch/pyexecutor/sampler.py | Major refactor of TorchSampler: new dataclasses, unified sampling functions, updated state. |
| tensorrt_llm/_torch/pyexecutor/py_executor.py | Propagated max_num_sequences, integrated SeqSlotManager, adjusted logit fields. |
| tensorrt_llm/_torch/pyexecutor/model_engine.py | Updated batch‐index logic (py_batch_idx) and input preparation to new sampler format. |
| tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py | Updated TorchSampler instantiation to use Args and added SeqSlotManager. |
Comments suppressed due to low confidence (3)
tensorrt_llm/_torch/pyexecutor/sampler.py:98
- Add unit tests for `top_k_sampling_batch`, `top_p_sampling_batch`, and the generic `sample` pipeline to validate sampling distributions, edge cases (e.g., top_k=1, top_p=0.0), and correct handling of tensor dimensions.
def top_k_sampling_batch(logits, top_k=50):
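As a point of reference for such tests, here is a minimal, hypothetical sketch of what a batched top-k sampler of this shape typically does (the actual implementation in `sampler.py` may differ in layout and details):

```python
import torch

def top_k_sampling_batch(logits: torch.Tensor, top_k: int = 50):
    """Sample one token per row, restricted to the top_k highest logits.

    Returns (next_tokens, softmax_probs). Sketch only; not the PR's code.
    """
    # Find the k-th largest logit per row and mask everything below it.
    values, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
    threshold = values[..., -1, None]
    masked = logits.masked_fill(logits < threshold, float("-inf"))
    # Renormalize over the surviving candidates and sample.
    probs = torch.softmax(masked, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return next_tokens, probs
```

Note the top_k=1 edge case mentioned above degenerates to greedy argmax, which makes it a convenient deterministic test.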
tensorrt_llm/_torch/pyexecutor/sampler.py:180
- [nitpick] Add a docstring explaining this helper's purpose, the expected format of `strategy` and `logits`, and what is returned (next_tokens and softmax probabilities).
def sample(strategy: Strategy, logits: torch.Tensor):
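To illustrate the kind of dispatch-and-return contract such a docstring would describe, here is a hedged sketch; the strategy encoding (a tuple whose first element names the method) is an assumption, not necessarily the PR's actual `Strategy` type:

```python
import torch

# Hypothetical encoding: ("greedy",), ("top_k", 50), etc.
Strategy = tuple

def sample(strategy: Strategy, logits: torch.Tensor):
    """Dispatch on `strategy` and return (next_tokens, softmax_probs)."""
    probs = torch.softmax(logits, dim=-1)
    if strategy[0] == "greedy":
        return probs.argmax(dim=-1), probs
    if strategy[0] == "top_k":
        k = min(strategy[1], probs.size(-1))
        # Sample among the k most probable tokens, then map back to vocab ids.
        values, indices = torch.topk(probs, k, dim=-1)
        picked = torch.multinomial(values / values.sum(-1, keepdim=True), 1)
        return indices.gather(-1, picked).squeeze(-1), probs
    raise ValueError(f"unknown sampling strategy: {strategy}")
```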
tensorrt_llm/_torch/speculative/utils.py:113
- [nitpick] Document this exception in the `get_spec_decoder` docstring so callers know it will raise for unknown modes, or consider returning `None` to match previous behavior if that was expected.
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
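The raise-versus-return-None tradeoff can be sketched with a small registry-style stand-in (the registry and factory names here are illustrative only, not the PR's actual code):

```python
# Hypothetical registry: speculative decoding mode name -> sampler factory.
_DECODER_FACTORIES = {
    "mtp": lambda sampler_args: ("MTPSampler", sampler_args),
    "eagle3_one_model": lambda sampler_args: ("Eagle3OneModelSampler", sampler_args),
}

def get_spec_decoder(sampler_args, spec_dec_mode: str):
    try:
        return _DECODER_FACTORIES[spec_dec_mode](sampler_args)
    except KeyError:
        # Fail loudly at construction time instead of silently returning None,
        # so a misconfigured mode surfaces immediately rather than downstream.
        raise ValueError(
            f"Unsupported speculative decoding mode: {spec_dec_mode}")
```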
@property
def all_requests(self) -> chain[LlmRequest]:
    return chain(self.context_requests, self.generation_requests)

changed in this PR to:

def all_requests(self) -> list[LlmRequest]:
Copilot AI, Jun 30, 2025
[nitpick] Consider returning an Iterable[LlmRequest] or Sequence[LlmRequest] instead of forcing a new list allocation on each call, or change the return annotation to list explicitly to reflect that behavior.
Suggested change (current line, then suggestion):

def all_requests(self) -> list[LlmRequest]:
def all_requests(self) -> Sequence[LlmRequest]:
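The tradeoff behind this nitpick is that `itertools.chain` is a single-use iterator, while a list costs one allocation per call but supports repeated iteration and `len()`. A small demonstration (the request values are placeholders):

```python
from itertools import chain

ctx, gen = ["req0", "req1"], ["req2"]

# A chain is exhausted after one pass.
it = chain(ctx, gen)
first_pass = list(it)
second_pass = list(it)  # empty: the iterator has been consumed

# A concrete list survives repeated iteration and supports len(),
# at the cost of a new allocation each time the property is called.
all_requests = ctx + gen
```

Annotating with `Sequence[LlmRequest]` would keep the implementation free to change while still promising callers a re-iterable, sized result.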
new_tokens,
gen_logits_host=gen_logits_host,
log_probs_host=log_probs_host)
new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
Copilot AI, Jun 30, 2025
Transferring the entire new_tokens tensor to CPU each iteration can be costly. If only a subset of slots is active, consider slicing new_tokens to only copy relevant indices and reduce data movement overhead.
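A sketch of the slicing idea, assuming a slot-major layout (one row per slot); the real tensor layout and slot bookkeeping in `sampler.py` may differ:

```python
import torch

def copy_active_tokens(new_tokens: torch.Tensor, active_slots: list) -> torch.Tensor:
    """Copy only the rows for active slots to host memory.

    Gathering first shrinks the device-to-host transfer from the whole
    new_tokens tensor to just the slots that are actually in use.
    """
    idx = torch.tensor(active_slots, dtype=torch.long, device=new_tokens.device)
    active = new_tokens.index_select(0, idx)
    return active.to(device="cpu", non_blocking=True)
```

Whether this wins in practice depends on how sparse the active-slot set is; for a mostly full batch the single bulk copy may still be cheaper than the gather.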
…state to trtllm samper tokens format (NVIDIA#5513) 58a8a8f - these changes were previously merged to main here. 6aef149 - the changes were temporarily reverted in main, due to a significant perf regression in models using the TorchSampler (observed by @byshiue). This PR is meant to re-merge these changes along with a fix to prevent the regression. The first commit of this PR is actually just the reverted revert - filter it out of the changes to see previously unmerged changes. Signed-off-by: Netanel Haber <[email protected]>
58a8a8f - these changes were previously merged to main here.
6aef149 - the changes were temporarily reverted in main, due to a significant perf regression in models using the `TorchSampler` (observed by @byshiue).
This PR is meant to re-merge these changes along with a fix to prevent the regression.
The first commit of this PR is actually just the reverted revert - filter it out of the changes to see the previously unmerged changes.